{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Speed comparison of gradient boosting libraries for shap values calculations"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we compare CatBoost, LightGBM and XGBoost for shap values calculations. All boosting algorithms were trained on GPU but shap evaluation was on CPU.\n",
    "\n",
    "We use the epsilon_normalized dataset from [here](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import datetime\n",
    "import os\n",
    "\n",
    "import catboost\n",
    "import lightgbm as lgb\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tqdm\n",
    "import xgboost as xgb\n",
    "from sklearn import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('0.11.2', '2.2.2', '0.81')"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "catboost.__version__, lgb.__version__, xgb.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data, train_target = datasets.load_svmlight_file(\"epsilon_normalized\")\n",
    "test_data, test_target = datasets.load_svmlight_file(\n",
    "    \"epsilon_normalized.t\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_iters = 1000\n",
    "lr = 0.1\n",
    "max_bin = 128\n",
    "gpu_device = \"0\"  # specify your GPU (used only for training)\n",
    "random_state = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_target[train_target == -1] = 0\n",
    "test_target[test_target == -1] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_data(data, label=None, mode=\"train\", boosting=None):\n",
    "    assert boosting is not None\n",
    "\n",
    "    if boosting == \"xgboost\":\n",
    "        return xgb.DMatrix(data, label)\n",
    "    elif boosting == \"lightgbm\":\n",
    "        if mode == \"train\":\n",
    "            return lgb.Dataset(data, label)\n",
    "        else:\n",
    "            return data\n",
    "    elif boosting == \"catboost\":\n",
    "        data = catboost.FeaturesData(num_feature_data=data)\n",
    "        return catboost.Pool(data, label)\n",
    "    else:\n",
    "        raise RuntimeError(\"Unknown boosting library\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_parameters(base_params, boosting=None, **kwargs):\n",
    "    assert boosting is not None\n",
    "    assert isinstance(base_params, dict)\n",
    "\n",
    "    params = copy.copy(base_params)\n",
    "    if boosting == \"xgboost\":\n",
    "        params[\"objective\"] = \"binary:logistic\"\n",
    "        params[\"max_depth\"] = kwargs[\"depth\"]\n",
    "        params[\"tree_method\"] = \"gpu_hist\"\n",
    "        params[\"gpu_id\"] = gpu_device\n",
    "    elif boosting == \"lightgbm\":\n",
    "        params[\"objective\"] = \"binary\"\n",
    "        params[\"device\"] = \"gpu\"\n",
    "        params[\"gpu_device_id\"] = gpu_device\n",
    "        params[\"num_leaves\"] = 2 ** kwargs[\"depth\"]\n",
    "    elif boosting == \"catboost\":\n",
    "        params[\"objective\"] = \"Logloss\"\n",
    "        params[\"task_type\"] = \"GPU\"\n",
    "        params[\"devices\"] = gpu_device\n",
    "        params[\"bootstrap_type\"] = \"Bernoulli\"\n",
    "        params[\"logging_level\"] = \"Silent\"\n",
    "    else:\n",
    "        raise RuntimeError(\"Unknown boosting library\")\n",
    "\n",
    "    return params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(data, params, num_iters, boosting=None):\n",
    "    assert boosting is not None\n",
    "    if boosting == \"xgboost\":\n",
    "        return xgb.train(params=params, dtrain=data, num_boost_round=num_iters)\n",
    "    elif boosting == \"lightgbm\":\n",
    "        return lgb.train(params=params, train_set=data, num_boost_round=num_iters)\n",
    "    elif boosting == \"catboost\":\n",
    "        return catboost.train(pool=data, params=params, num_boost_round=num_iters)\n",
    "    else:\n",
    "        raise RuntimeError(\"Unknown boosting library\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_shap(model, data, boosting=None):\n",
    "    assert boosting is not None\n",
    "    if boosting == \"xgboost\":\n",
    "        return model.predict(data, pred_contribs=True)\n",
    "    elif boosting == \"lightgbm\":\n",
    "        return model.predict(data, pred_contrib=True)\n",
    "    elif boosting == \"catboost\":\n",
    "        return model.get_feature_importance(data, fstr_type=\"ShapValues\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_path(boosting, params):\n",
    "    fname = [boosting]\n",
    "    for key, value in sorted(params.items()):\n",
    "        fname.append(str(key))\n",
    "        fname.append(str(value))\n",
    "    fname = \"_\".join(fname)\n",
    "    fname = fname.replace(\".\", \"\")\n",
    "    fname += \".model\"\n",
    "    return fname"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(fname, boosting):\n",
    "    if boosting == \"xgboost\":\n",
    "        bst = xgb.Booster(model_file=fname)\n",
    "        bst.load_model(fname)\n",
    "    elif boosting == \"lightgbm\":\n",
    "        bst = lgb.Booster(model_file=fname)\n",
    "    elif boosting == \"catboost\":\n",
    "        bst = catboost.CatBoost()\n",
    "        bst.load_model(fname)\n",
    "    else:\n",
    "        raise RuntimeError(\"Unknown boosting\")\n",
    "    return bst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_params = {\"learning_rate\": lr, \"max_bin\": max_bin, \"random_state\": random_state}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "result = []\n",
    "\n",
    "boosting_list = [\"xgboost\", \"catboost\", \"lightgbm\"]\n",
    "depth_list = [2, 4, 6, 8, 10]\n",
    "lens_list = [1000, 5000, 10000]\n",
    "\n",
    "\n",
    "for gb_type in boosting_list:\n",
    "    print(f\"{gb_type} is going\")\n",
    "\n",
    "    for size_test in lens_list:\n",
    "        print(f\"size test {size_test}\")\n",
    "        sep_test_data = test_data[:size_test]\n",
    "        sep_test_target = test_target[:size_test]\n",
    "\n",
    "        # comment this line if you have already trained all models\n",
    "        train_preprocessed = preprocess_data(train_data, train_target, boosting=gb_type)\n",
    "\n",
    "        dense_test = sep_test_data.todense().A.astype(np.float32)\n",
    "\n",
    "        for depth in tqdm.tqdm(depth_list):\n",
    "            start_test_preproc = datetime.datetime.now()\n",
    "            test_preprocessed = preprocess_data(dense_test, sep_test_target, mode=\"test\", boosting=gb_type)\n",
    "\n",
    "            finish_test_preproc = datetime.datetime.now()\n",
    "            preprocessing_delta = finish_test_preproc - start_test_preproc\n",
    "            preprocessing_delta = preprocessing_delta.total_seconds()\n",
    "\n",
    "            params = create_parameters(base_params, boosting=gb_type, depth=depth)\n",
    "            params[\"depth\"] = depth\n",
    "            fname = create_path(gb_type, params)\n",
    "            if os.path.exists(fname):\n",
    "                print(\"model exist\")\n",
    "                bst = load_model(fname, boosting=gb_type)\n",
    "            else:\n",
    "                print(\"model is training\")\n",
    "                start_train = datetime.datetime.now()\n",
    "                bst = train(train_preprocessed, params, num_iters=num_iters, boosting=gb_type)\n",
    "                finish_train = datetime.datetime.now()\n",
    "                delta_train = finish_train - start_train\n",
    "                delta_train = int(delta_train.total_seconds() * 1000)\n",
    "                bst.save_model(fname)\n",
    "\n",
    "            start_time = datetime.datetime.now()\n",
    "            preds = predict_shap(bst, test_preprocessed, boosting=gb_type)\n",
    "            assert preds.shape == (sep_test_data.shape[0], sep_test_data.shape[1] + 1)\n",
    "            finish_time = datetime.datetime.now()\n",
    "\n",
    "            delta = finish_time - start_time\n",
    "            delta = delta.total_seconds()\n",
    "\n",
    "            current_res = {\n",
    "                \"preprocessing_time\": preprocessing_delta,\n",
    "                \"boosting\": gb_type,\n",
    "                \"test_size\": size_test,\n",
    "                \"depth\": depth,\n",
    "                \"time\": delta,\n",
    "            }\n",
    "\n",
    "            result.append(current_res)\n",
    "\n",
    "        print(\"*\" * 40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_df = pd.DataFrame(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_df.to_csv(f\"shap_benchmark_{max_bin}_max_bin_with_test_sizes.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>boosting</th>\n",
       "      <th>catboost</th>\n",
       "      <th>lightgbm</th>\n",
       "      <th>xgboost</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>test_size</th>\n",
       "      <th>depth</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">1000</th>\n",
       "      <th>2</th>\n",
       "      <td>0.311027</td>\n",
       "      <td>0.090156</td>\n",
       "      <td>0.112515</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.281931</td>\n",
       "      <td>0.578531</td>\n",
       "      <td>0.300671</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0.464603</td>\n",
       "      <td>4.159926</td>\n",
       "      <td>1.468442</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>4.918599</td>\n",
       "      <td>23.844245</td>\n",
       "      <td>7.847191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>93.152000</td>\n",
       "      <td>119.527824</td>\n",
       "      <td>30.872254</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">5000</th>\n",
       "      <th>2</th>\n",
       "      <td>1.171963</td>\n",
       "      <td>0.284673</td>\n",
       "      <td>0.241316</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1.081119</td>\n",
       "      <td>2.094985</td>\n",
       "      <td>0.931881</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1.319114</td>\n",
       "      <td>20.624486</td>\n",
       "      <td>6.498283</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>5.807985</td>\n",
       "      <td>118.552238</td>\n",
       "      <td>38.992395</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>95.049909</td>\n",
       "      <td>601.251603</td>\n",
       "      <td>153.408904</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"5\" valign=\"top\">10000</th>\n",
       "      <th>2</th>\n",
       "      <td>2.048301</td>\n",
       "      <td>0.621454</td>\n",
       "      <td>0.509722</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2.263058</td>\n",
       "      <td>4.291201</td>\n",
       "      <td>1.935541</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>2.396371</td>\n",
       "      <td>42.788038</td>\n",
       "      <td>12.981580</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>7.078056</td>\n",
       "      <td>240.614644</td>\n",
       "      <td>77.883250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>95.680684</td>\n",
       "      <td>1189.685032</td>\n",
       "      <td>306.529277</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "boosting          catboost     lightgbm     xgboost\n",
       "test_size depth                                    \n",
       "1000      2       0.311027     0.090156    0.112515\n",
       "          4       0.281931     0.578531    0.300671\n",
       "          6       0.464603     4.159926    1.468442\n",
       "          8       4.918599    23.844245    7.847191\n",
       "          10     93.152000   119.527824   30.872254\n",
       "5000      2       1.171963     0.284673    0.241316\n",
       "          4       1.081119     2.094985    0.931881\n",
       "          6       1.319114    20.624486    6.498283\n",
       "          8       5.807985   118.552238   38.992395\n",
       "          10     95.049909   601.251603  153.408904\n",
       "10000     2       2.048301     0.621454    0.509722\n",
       "          4       2.263058     4.291201    1.935541\n",
       "          6       2.396371    42.788038   12.981580\n",
       "          8       7.078056   240.614644   77.883250\n",
       "          10     95.680684  1189.685032  306.529277"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_df = pd.read_csv(\n",
    "    \"shap_benchmark_128_max_bin_with_test_sizes.csv\",\n",
    ")\n",
    "result_df.pivot_table(index=[\"test_size\", \"depth\"], columns=\"boosting\", values=\"time\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>boosting</th>\n",
       "      <th>catboost</th>\n",
       "      <th>lightgbm</th>\n",
       "      <th>xgboost</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>test_size</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1000</th>\n",
       "      <td>0.069569</td>\n",
       "      <td>0.002816</td>\n",
       "      <td>0.011025</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5000</th>\n",
       "      <td>0.349831</td>\n",
       "      <td>0.000006</td>\n",
       "      <td>0.047836</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10000</th>\n",
       "      <td>0.770179</td>\n",
       "      <td>0.000006</td>\n",
       "      <td>0.089032</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "boosting   catboost  lightgbm   xgboost\n",
       "test_size                              \n",
       "1000       0.069569  0.002816  0.011025\n",
       "5000       0.349831  0.000006  0.047836\n",
       "10000      0.770179  0.000006  0.089032"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_df.pivot_table(index=\"test_size\", columns=\"boosting\", values=\"preprocessing_time\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
